{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n",
        "sidebar_position: 4\n",
        "---\n",
        "\n",
        "# How to create a custom chat model class\n",
        "\n",
        "```{=mdx}\n",
        ":::info Prerequisites\n",
        "\n",
        "This guide assumes familiarity with the following concepts:\n",
        "\n",
        "- [Chat models](/docs/concepts/chat_models)\n",
        "\n",
        ":::\n",
        "```\n",
        "\n",
        "This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.\n",
        "\n",
        "There are a few required things that a chat model needs to implement after extending the [`SimpleChatModel` class](https://api.js.langchain.com/classes/langchain_core.language_models_chat_models.SimpleChatModel.html):\n",
        "\n",
        "- A `_call` method that takes in a list of messages and call options (which includes things like `stop` sequences), and returns a string.\n",
        "- A `_llmType` method that returns a string. Used for logging purposes only.\n",
        "\n",
        "You can also implement the following optional method:\n",
        "\n",
        "- A `_streamResponseChunks` method that returns an `AsyncGenerator` and yields [`ChatGenerationChunk`s](https://api.js.langchain.com/classes/langchain_core.outputs.ChatGenerationChunk.html). This allows the LLM to support streaming outputs.\n",
        "\n",
        "Let's implement a very simple custom chat model that just echoes back the first `n` characters of the input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "import {\n",
        "  SimpleChatModel,\n",
        "  type BaseChatModelParams,\n",
        "} from \"@langchain/core/language_models/chat_models\";\n",
        "import { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\n",
        "import { AIMessageChunk, type BaseMessage } from \"@langchain/core/messages\";\n",
        "import { ChatGenerationChunk } from \"@langchain/core/outputs\";\n",
        "\n",
        "/** Constructor options for `CustomChatModel`. */\n",
        "interface CustomChatModelInput extends BaseChatModelParams {\n",
        "  /** How many leading characters of the first message to echo back. */\n",
        "  n: number;\n",
        "}\n",
        "\n",
        "/**\n",
        " * Toy chat model that echoes the first `n` characters of the first message,\n",
        " * either all at once (`_call`) or one character per chunk\n",
        " * (`_streamResponseChunks`).\n",
        " */\n",
        "class CustomChatModel extends SimpleChatModel {\n",
        "  n: number;\n",
        "\n",
        "  constructor(fields: CustomChatModelInput) {\n",
        "    super(fields);\n",
        "    const { n } = fields;\n",
        "    this.n = n;\n",
        "  }\n",
        "\n",
        "  /** Label identifying this model type. */\n",
        "  _llmType() {\n",
        "    return \"custom\";\n",
        "  }\n",
        "\n",
        "  /**\n",
        "   * Echoes the first `n` characters of the first message.\n",
        "   * Throws if `messages` is empty or the first message's content is not a string.\n",
        "   */\n",
        "  async _call(\n",
        "    messages: BaseMessage[],\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): Promise<string> {\n",
        "    if (messages.length === 0) {\n",
        "      throw new Error(\"No messages provided.\");\n",
        "    }\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "    const firstContent = messages[0].content;\n",
        "    if (typeof firstContent !== \"string\") {\n",
        "      throw new Error(\"Multimodal messages are not supported.\");\n",
        "    }\n",
        "    return firstContent.slice(0, this.n);\n",
        "  }\n",
        "\n",
        "  /**\n",
        "   * Streaming counterpart of `_call`: yields one chunk per character of the\n",
        "   * echoed text and reports each one through `runManager`.\n",
        "   */\n",
        "  async *_streamResponseChunks(\n",
        "    messages: BaseMessage[],\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): AsyncGenerator<ChatGenerationChunk> {\n",
        "    if (messages.length === 0) {\n",
        "      throw new Error(\"No messages provided.\");\n",
        "    }\n",
        "    const firstContent = messages[0].content;\n",
        "    if (typeof firstContent !== \"string\") {\n",
        "      throw new Error(\"Multimodal messages are not supported.\");\n",
        "    }\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "    const echoed = firstContent.slice(0, this.n);\n",
        "    for (const character of echoed) {\n",
        "      const chunk = new ChatGenerationChunk({\n",
        "        message: new AIMessageChunk({ content: character }),\n",
        "        text: character,\n",
        "      });\n",
        "      yield chunk;\n",
        "      // Trigger the appropriate callback for new chunks\n",
        "      await runManager?.handleLLMNewToken(character);\n",
        "    }\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can now use this as any other chat model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "AIMessage {\n",
            "  lc_serializable: true,\n",
            "  lc_kwargs: {\n",
            "    content: 'I am',\n",
            "    tool_calls: [],\n",
            "    invalid_tool_calls: [],\n",
            "    additional_kwargs: {},\n",
            "    response_metadata: {}\n",
            "  },\n",
            "  lc_namespace: [ 'langchain_core', 'messages' ],\n",
            "  content: 'I am',\n",
            "  name: undefined,\n",
            "  additional_kwargs: {},\n",
            "  response_metadata: {},\n",
            "  id: undefined,\n",
            "  tool_calls: [],\n",
            "  invalid_tool_calls: [],\n",
            "  usage_metadata: undefined\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "const chatModel = new CustomChatModel({ n: 4 });\n",
        "\n",
        "await chatModel.invoke([[\"human\", \"I am an LLM\"]]);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And support streaming:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "AIMessageChunk {\n",
            "  lc_serializable: true,\n",
            "  lc_kwargs: {\n",
            "    content: 'I',\n",
            "    tool_calls: [],\n",
            "    invalid_tool_calls: [],\n",
            "    tool_call_chunks: [],\n",
            "    additional_kwargs: {},\n",
            "    response_metadata: {}\n",
            "  },\n",
            "  lc_namespace: [ 'langchain_core', 'messages' ],\n",
            "  content: 'I',\n",
            "  name: undefined,\n",
            "  additional_kwargs: {},\n",
            "  response_metadata: {},\n",
            "  id: undefined,\n",
            "  tool_calls: [],\n",
            "  invalid_tool_calls: [],\n",
            "  tool_call_chunks: [],\n",
            "  usage_metadata: undefined\n",
            "}\n",
            "AIMessageChunk {\n",
            "  lc_serializable: true,\n",
            "  lc_kwargs: {\n",
            "    content: ' ',\n",
            "    tool_calls: [],\n",
            "    invalid_tool_calls: [],\n",
            "    tool_call_chunks: [],\n",
            "    additional_kwargs: {},\n",
            "    response_metadata: {}\n",
            "  },\n",
            "  lc_namespace: [ 'langchain_core', 'messages' ],\n",
            "  content: ' ',\n",
            "  name: undefined,\n",
            "  additional_kwargs: {},\n",
            "  response_metadata: {},\n",
            "  id: undefined,\n",
            "  tool_calls: [],\n",
            "  invalid_tool_calls: [],\n",
            "  tool_call_chunks: [],\n",
            "  usage_metadata: undefined\n",
            "}\n",
            "AIMessageChunk {\n",
            "  lc_serializable: true,\n",
            "  lc_kwargs: {\n",
            "    content: 'a',\n",
            "    tool_calls: [],\n",
            "    invalid_tool_calls: [],\n",
            "    tool_call_chunks: [],\n",
            "    additional_kwargs: {},\n",
            "    response_metadata: {}\n",
            "  },\n",
            "  lc_namespace: [ 'langchain_core', 'messages' ],\n",
            "  content: 'a',\n",
            "  name: undefined,\n",
            "  additional_kwargs: {},\n",
            "  response_metadata: {},\n",
            "  id: undefined,\n",
            "  tool_calls: [],\n",
            "  invalid_tool_calls: [],\n",
            "  tool_call_chunks: [],\n",
            "  usage_metadata: undefined\n",
            "}\n",
            "AIMessageChunk {\n",
            "  lc_serializable: true,\n",
            "  lc_kwargs: {\n",
            "    content: 'm',\n",
            "    tool_calls: [],\n",
            "    invalid_tool_calls: [],\n",
            "    tool_call_chunks: [],\n",
            "    additional_kwargs: {},\n",
            "    response_metadata: {}\n",
            "  },\n",
            "  lc_namespace: [ 'langchain_core', 'messages' ],\n",
            "  content: 'm',\n",
            "  name: undefined,\n",
            "  additional_kwargs: {},\n",
            "  response_metadata: {},\n",
            "  id: undefined,\n",
            "  tool_calls: [],\n",
            "  invalid_tool_calls: [],\n",
            "  tool_call_chunks: [],\n",
            "  usage_metadata: undefined\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "const stream = await chatModel.stream([[\"human\", \"I am an LLM\"]]);\n",
        "\n",
        "for await (const chunk of stream) {\n",
        "  console.log(chunk);\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Richer outputs\n",
        "\n",
        "If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the [`BaseChatModel`](https://api.js.langchain.com/classes/langchain_core.language_models_chat_models.BaseChatModel.html) class and implement the lower level\n",
        "`_generate` method. It also takes a list of `BaseMessage`s as input, but requires you to construct and return a `ChatResult` object (wrapping one or more `ChatGeneration`s) that permits additional metadata.\n",
        "Here's an example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { AIMessage, BaseMessage } from \"@langchain/core/messages\";\n",
        "import { ChatResult } from \"@langchain/core/outputs\";\n",
        "import {\n",
        "  BaseChatModel,\n",
        "  BaseChatModelCallOptions,\n",
        "  BaseChatModelParams,\n",
        "} from \"@langchain/core/language_models/chat_models\";\n",
        "import { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\n",
        "\n",
        "/** Call-time options; adds nothing beyond the base chat model call options. */\n",
        "interface AdvancedCustomChatModelOptions\n",
        "  extends BaseChatModelCallOptions {}\n",
        "\n",
        "/** Constructor parameters for `AdvancedCustomChatModel`. */\n",
        "interface AdvancedCustomChatModelParams extends BaseChatModelParams {\n",
        "  /** Maximum number of leading characters of the first message to echo back. */\n",
        "  n: number;\n",
        "}\n",
        "\n",
        "/**\n",
        " * Echo model built on `BaseChatModel`: implements `_generate` directly so it\n",
        " * can return extra metadata (`llmOutput`) next to the generated message.\n",
        " */\n",
        "class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {\n",
        "  n: number;\n",
        "\n",
        "  static lc_name(): string {\n",
        "    return \"AdvancedCustomChatModel\";\n",
        "  }\n",
        "\n",
        "  constructor(fields: AdvancedCustomChatModelParams) {\n",
        "    super(fields);\n",
        "    this.n = fields.n;\n",
        "  }\n",
        "\n",
        "  /**\n",
        "   * Echoes the first `n` characters of the first message and reports usage.\n",
        "   *\n",
        "   * @param messages Input messages; only the first one is read.\n",
        "   * @param options Parsed call options (not read by this implementation).\n",
        "   * @param runManager Optional callback manager for this run (not read by this implementation).\n",
        "   * @returns A `ChatResult` with one generation plus `llmOutput.tokenUsage`.\n",
        "   * @throws Error if `messages` is empty or the first message's content is not a string.\n",
        "   */\n",
        "  async _generate(\n",
        "    messages: BaseMessage[],\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): Promise<ChatResult> {\n",
        "    if (!messages.length) {\n",
        "      throw new Error(\"No messages provided.\");\n",
        "    }\n",
        "    if (typeof messages[0].content !== \"string\") {\n",
        "      throw new Error(\"Multimodal messages are not supported.\");\n",
        "    }\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "    const content = messages[0].content.slice(0, this.n);\n",
        "    // Report what was actually returned: `slice` yields fewer than `n`\n",
        "    // characters when the input is shorter than `n`.\n",
        "    const tokenUsage = {\n",
        "      usedTokens: content.length,\n",
        "    };\n",
        "    return {\n",
        "      generations: [{ message: new AIMessage({ content }), text: content }],\n",
        "      llmOutput: { tokenUsage },\n",
        "    };\n",
        "  }\n",
        "\n",
        "  _llmType(): string {\n",
        "    return \"advanced_custom_chat_model\";\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will pass the additional returned information in callback events and in the `streamEvents` method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{\n",
            "  \"event\": \"on_chat_model_end\",\n",
            "  \"data\": {\n",
            "    \"output\": {\n",
            "      \"lc\": 1,\n",
            "      \"type\": \"constructor\",\n",
            "      \"id\": [\n",
            "        \"langchain_core\",\n",
            "        \"messages\",\n",
            "        \"AIMessage\"\n",
            "      ],\n",
            "      \"kwargs\": {\n",
            "        \"content\": \"I am\",\n",
            "        \"tool_calls\": [],\n",
            "        \"invalid_tool_calls\": [],\n",
            "        \"additional_kwargs\": {},\n",
            "        \"response_metadata\": {\n",
            "          \"tokenUsage\": {\n",
            "            \"usedTokens\": 4\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "  },\n",
            "  \"run_id\": \"11dbdef6-1b91-407e-a497-1a1ce2974788\",\n",
            "  \"name\": \"AdvancedCustomChatModel\",\n",
            "  \"tags\": [],\n",
            "  \"metadata\": {\n",
            "    \"ls_model_type\": \"chat\"\n",
            "  }\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "const chatModel = new AdvancedCustomChatModel({ n: 4 });\n",
        "\n",
        "const eventStream = await chatModel.streamEvents([[\"human\", \"I am an LLM\"]], {\n",
        "  version: \"v2\",\n",
        "});\n",
        "\n",
        "for await (const event of eventStream) {\n",
        "  if (event.event === \"on_chat_model_end\") {\n",
        "    console.log(JSON.stringify(event, null, 2));\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Tracing (advanced)\n",
        "\n",
        "If you are implementing a custom chat model and want to use it with a tracing service like [LangSmith](https://smith.langchain.com/),\n",
        "you can automatically log params used for a given invocation by implementing the `invocationParams()` method on the model.\n",
        "\n",
        "This method is purely optional, but anything it returns will be logged as metadata for the trace.\n",
        "\n",
        "Here's one pattern you might use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\n",
        "import { BaseChatModel, type BaseChatModelCallOptions, type BaseChatModelParams } from \"@langchain/core/language_models/chat_models\";\n",
        "import { AIMessage, BaseMessage } from \"@langchain/core/messages\";\n",
        "import { ChatResult } from \"@langchain/core/outputs\";\n",
        "\n",
        "interface CustomChatModelOptions extends BaseChatModelCallOptions {\n",
        "  // Some required or optional inner args\n",
        "  tools: Record<string, any>[];\n",
        "}\n",
        "\n",
        "interface CustomChatModelParams extends BaseChatModelParams {\n",
        "  temperature: number;\n",
        "  n: number;\n",
        "}\n",
        "\n",
        "/**\n",
        " * Chat model skeleton showing how `invocationParams()` can expose the\n",
        " * parameters used for a given invocation.\n",
        " */\n",
        "class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {\n",
        "  temperature: number;\n",
        "\n",
        "  n: number;\n",
        "\n",
        "  static lc_name(): string {\n",
        "    return \"CustomChatModel\";\n",
        "  }\n",
        "\n",
        "  constructor(fields: CustomChatModelParams) {\n",
        "    super(fields);\n",
        "    this.temperature = fields.temperature;\n",
        "    this.n = fields.n;\n",
        "  }\n",
        "\n",
        "  // Anything returned in this method will be logged as metadata in the trace.\n",
        "  // It is common to pass it any options used to invoke the function.\n",
        "  invocationParams(options?: this[\"ParsedCallOptions\"]) {\n",
        "    return {\n",
        "      tools: options?.tools,\n",
        "      n: this.n,\n",
        "    };\n",
        "  }\n",
        "\n",
        "  /**\n",
        "   * Sends the messages, together with the invocation params, to an external\n",
        "   * API and wraps the returned text in a `ChatResult`.\n",
        "   *\n",
        "   * @throws Error if `messages` is empty or the first message's content is not a string.\n",
        "   */\n",
        "  async _generate(\n",
        "    messages: BaseMessage[],\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): Promise<ChatResult> {\n",
        "    if (!messages.length) {\n",
        "      throw new Error(\"No messages provided.\");\n",
        "    }\n",
        "    if (typeof messages[0].content !== \"string\") {\n",
        "      throw new Error(\"Multimodal messages are not supported.\");\n",
        "    }\n",
        "    const additionalParams = this.invocationParams(options);\n",
        "    // `someAPIRequest` is a placeholder for your own API call; it is not\n",
        "    // defined in this notebook and must be supplied before `_generate` can run.\n",
        "    const content = await someAPIRequest(messages, additionalParams);\n",
        "    return {\n",
        "      generations: [{ message: new AIMessage({ content }), text: content }],\n",
        "      llmOutput: {},\n",
        "    };\n",
        "  }\n",
        "\n",
        "  _llmType(): string {\n",
        "    return \"advanced_custom_chat_model\";\n",
        "  }\n",
        "}"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "TypeScript",
      "language": "typescript",
      "name": "tslab"
    },
    "language_info": {
      "codemirror_mode": {
        "mode": "typescript",
        "name": "javascript",
        "typescript": true
      },
      "file_extension": ".ts",
      "mimetype": "text/typescript",
      "name": "typescript",
      "version": "3.7.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}